Improve memory management in clustering_qr.kmeans_plusplus #775
@RobertoDF Are you able to share the data that you're seeing this problem with so that I can test this myself?
Sure, compressing the files now.
In the zip there is a Jupyter notebook that shows the problem and the specific Xd tensor that causes the crash on my machine. I put in both the standard and modified versions of kmeans_plusplus. The old one should crash; if you run the new one afterwards, it should run without errors.
Just noticed that in the notebook I didn't include the change at line
@RobertoDF Those are not the files I would need. I mean the full recording, either a .bin file or whatever format you converted from, along with the probe file you used.
…ing_qr.kmeans_plusplus
This last commit seems to really solve the OOM problems.
Hello, I tried to use your last commit, but I'm still getting a CUDA OOM error in the final clustering phase. How much dedicated GPU memory do you have? I have 8 GB, and Kilosort used on average 6-7 GB throughout sorting until crashing at the end.
I have 12 GB. Without the modification I would often get OOM inside the kmeans_plusplus func. Which line is problematic for you exactly? What is the error message saying? Also, what is your recording duration?
Thanks for the quick response. Yes, kmeans_plusplus inside of clustering_qr seems to be the cause of each crash every time. My recording duration is 90 min. Here's the problematic line and the kilosort log if it helps: File "C:\miniconda3\envs\kilosort\lib\site-packages\kilosort\clustering_qr.py", line 215, in kmeans_plusplus
Mmm, never had a crash at that line. If you use the normal version, not my fork, does it also crash at the same line?
Just ran another attempt with the normal version. Here's the problem line: File "C:\Users\ColginLab\miniconda3\envs\kilosort\lib\site-packages\kilosort\clustering_qr.py", line 167, in kmeans_plusplus
OK, that was a problematic line for me too, and indeed I would expect my solution to solve that one. But I never had a problem at the line you showed me before. Maybe it can be optimized further, but I won't have time to check this in the near future. If you have access to a 12 GB GPU, I would expect that to solve the problem.
Alright, I'll look into getting more GPU memory. Thanks for the help!
@RobertoDF Are you able to provide a bit more explanation for the changes you proposed? I can see from other issues that they're helping with some memory problems, but I'm having a hard time finding any information in the PyTorch docs that would explain why these changes prevent copies / otherwise reduce memory usage.
Sure! I just stepped through the code with a debugger breakpoint while watching GPU memory consumption, substituting lines (and checking that the output stayed identical) until I found a combination that avoided the unnecessary creation of large arrays on the GPU without sacrificing any speed (at least in my tests). Loads of trial and error!!
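For anyone who wants to repeat that kind of trial and error outside a debugger, here is a minimal sketch of the workflow: run two candidate versions of a line, check that the outputs agree, and compare peak GPU memory. The two functions and the helper `run_and_measure` are made up for illustration and are not the lines changed in this PR. Peak memory is only reported when the tensors live on a CUDA device; on CPU it is `None`.

```python
import torch


def run_and_measure(fn, *args):
    """Run fn(*args); return (result, peak GPU bytes), or (result, None) on CPU."""
    on_gpu = any(a.is_cuda for a in args if isinstance(a, torch.Tensor))
    if on_gpu:
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
    result = fn(*args)
    if on_gpu:
        torch.cuda.synchronize()
        return result, torch.cuda.max_memory_allocated()
    return result, None


def scaled_sum_copy(X, w):
    # Hypothetical "before": X * w materializes a temporary as large as X.
    return (X * w).sum(dim=1)


def scaled_sum_no_copy(X, w):
    # Hypothetical "after": a matrix-vector product gives the same values
    # without the full-size temporary.
    return X @ w


torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
X = torch.randn(1000, 64, device=device)
w = torch.randn(64, device=device)

out_copy, peak_copy = run_and_measure(scaled_sum_copy, X, w)
out_lean, peak_lean = run_and_measure(scaled_sum_no_copy, X, w)

# Floating-point summation order can differ, so compare with a tolerance.
same = torch.allclose(out_copy, out_lean, rtol=1e-4, atol=1e-4)
print("outputs match:", same)
print("peak bytes (copy, no copy):", peak_copy, peak_lean)
```

The tolerance-based comparison is a choice made for this sketch; when the substitution only changes where a result is stored, an exact check with `torch.equal` would be the stricter test.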
This modification avoids creating unnecessary tensors, or deletes them immediately, in clustering_qr.kmeans_plusplus. It helps with OOM errors (#746) happening at
Kilosort/kilosort/clustering_qr.py
Line 202 in b2f5ded
Xg can sometimes be quite big (5 GB in the case where I get OOM); in both of these lines a copy of Xg was created unnecessarily on the GPU.
Kilosort/kilosort/clustering_qr.py
Lines 166 to 168 in b2f5ded
&
Kilosort/kilosort/clustering_qr.py
Line 202 in b2f5ded
The solution to line 202 does not impact speed. The solution to line 167 might impact speed, but not noticeably in my tests; for this reason I didn't extend the reach of the clear_cache arg to the kmeans_plusplus func.
Tested on PyTorch 2.1.2 and 2.4.1.
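As a rough illustration of the general pattern (not the actual diff in this PR), the sketch below shows two common ways to keep temporaries small in PyTorch: doing an operation in place rather than allocating a second full-size tensor, and dropping a reference with `del` as soon as it is no longer needed. The function name `squared_distances_to_center` and its `clear_cache` parameter are hypothetical and only mirror the idea of an optional cache-clearing switch.

```python
import torch


def squared_distances_to_center(X, center, clear_cache=False):
    """Hypothetical helper: squared distance of each row of X to center."""
    diff = X - center      # one temporary the size of X
    diff.square_()         # square in place rather than allocating diff ** 2
    dist = diff.sum(dim=1)
    del diff               # drop the reference so the allocator can reuse the block
    if clear_cache and X.is_cuda:
        # Returns cached, unused blocks to the driver; this can cost time,
        # which is why it is optional in this sketch.
        torch.cuda.empty_cache()
    return dist


X = torch.tensor([[0.0, 0.0], [3.0, 4.0]])
center = torch.zeros(2)
print(squared_distances_to_center(X, center))
```

Note that `X` itself is left untouched here; only the temporary `diff` is modified in place.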